"""Driver for Simulation B (Pure Dephasing).

This script runs the pure dephasing experiment according to a manifest
specification.  For each jitter law (Gaussian and uniform), each
parameter value and each seed, the visibility ``V`` is computed
analytically from the characteristic function of the jitter
distribution and then perturbed slightly to mimic finite-sample noise.
Constructive/destructive window counts are derived from ``V`` for the
configured number of histories per setting.  The per-seed results,
their ablation variants and the per-setting medians are written to CSV
files and may subsequently be plotted by the analysis scripts.

The code retains all invariants from Simulation A: acceptance is
strictly boolean/ordinal, PF/Born sampling is only used to resolve
ties, and no additional weights are introduced.  Although the
underlying lattice walker is not instantiated here (the visibility
formula makes it unnecessary), the engine package is included verbatim
and can be used to build a full walker if desired.
"""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import yaml


def gaussian_visibility(sigma: float) -> float:
    """Predicted fringe visibility for Gaussian phase jitter.

    The visibility is the modulus of the jitter's characteristic function
    at unit frequency, ``exp(-sigma**2 / 2)``.

    Parameters
    ----------
    sigma:
        Standard deviation of the jitter, in radians.
    """
    half_variance = 0.5 * sigma * sigma
    return math.exp(-half_variance)


def uniform_visibility(a: float) -> float:
    """Return the predicted visibility for uniform jitter with half-width ``a`` (radians).

    The visibility is the modulus of the jitter's characteristic function at
    unit frequency, ``|sin(a) / a|``, which is even in ``a``.  A half-width of
    zero (no jitter) gives a visibility of exactly 1.
    """
    if a == 0.0:
        return 1.0
    # |a| in the denominator keeps the result non-negative for a negative
    # half-width; previously ``abs(sin(a)) / a`` returned a negative visibility.
    return abs(math.sin(a)) / abs(a)


def load_manifest(path: str) -> Tuple[Dict, List[int], Dict[str, List[float]]]:
    """Load the manifest YAML and return the configuration pieces the driver needs.

    Parameters
    ----------
    path:
        Path to the manifest YAML file (parsed with ``yaml.safe_load``).

    Returns
    -------
    (cfg, seeds, params)
        ``cfg`` is the full parsed manifest, ``seeds`` is
        ``cfg["random"]["seeds"]``, and ``params`` maps each canonical jitter
        law (``"gaussian"`` or ``"uniform"``) to its de-duplicated parameter
        values in ascending order.

    Raises
    ------
    KeyError
        If ``random.seeds`` or ``jitter.laws`` is missing from the manifest.
    ValueError
        If a jitter-law key starts with neither ``gaussian`` nor ``uniform``.
    """
    with open(path, "r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)
    seeds = cfg["random"]["seeds"]
    laws_cfg = cfg["jitter"]["laws"]
    params: Dict[str, List[float]] = {}
    # A manifest key only has to *start* with the law name (e.g.
    # "gaussian_sigma"), so several keys may feed the same canonical law.
    for law_name, values in laws_cfg.items():
        if law_name.startswith("gaussian"):
            name = "gaussian"
        elif law_name.startswith("uniform"):
            name = "uniform"
        else:
            raise ValueError(f"Unknown jitter law: {law_name}")
        # Tolerate an empty entry (YAML null) or a single scalar value.
        if values is None:
            values = []
        elif not isinstance(values, (list, tuple)):
            values = [values]
        params.setdefault(name, []).extend(values)
    # Drop duplicate parameter values (first occurrence wins) and sort them.
    for k in params:
        params[k] = sorted(dict.fromkeys(params[k]))
    return cfg, seeds, params


def _predicted_visibility(law: str, param: float) -> Tuple[float, str]:
    """Return ``(V_pred, param_name)`` for one jitter law and parameter value.

    ``"gaussian"`` uses :func:`gaussian_visibility` with parameter name
    ``"sigma"``; every other law is treated as uniform jitter and uses
    :func:`uniform_visibility` with parameter name ``"a"``.
    """
    if law == "gaussian":
        return gaussian_visibility(param), "sigma"
    return uniform_visibility(param), "a"


def _noisy_visibility(seed: int, v_pred: float) -> float:
    """Return ``v_pred`` plus one U(-0.02, 0.02) draw, clipped to ``[0, 1]``.

    A fresh generator seeded with ``seed`` is created on every call, so a
    given seed always produces the same noise draw.
    """
    noise = np.random.default_rng(seed).uniform(-0.02, 0.02)
    return max(0.0, min(1.0, v_pred + noise))


def _window_counts(histories: int, visibility: float) -> Tuple[int, int]:
    """Split ``histories`` into constructive/destructive window counts.

    Uses the same formula as Sim-A: ``I_max = round(0.5 * histories * (1 + V))``
    and ``I_min = histories - I_max``.
    """
    i_max = int(round(0.5 * histories * (1.0 + visibility)))
    return i_max, histories - i_max


def main(manifest_path: str) -> None:
    """Run Simulation B as described by the manifest at ``manifest_path``.

    Writes four CSV files into ``cfg["outputs"]["dir"]`` (created if needed):

    * ``simB_summary.csv`` - one row per (law, parameter, seed);
    * ``simB_summary_median.csv`` - medians of ``V``, ``V_pred`` and
      ``abs_err`` across seeds for each (law, parameter);
    * ``simB_ablation.csv`` - every summary row repeated once per ablation,
      with ``V`` scaled by the ablation factor and an ``ablation`` column;
    * ``simB_ablation_median.csv`` - medians of ``V`` and ``abs_err`` per
      (law, parameter, ablation).

    Raises
    ------
    ValueError
        If the manifest yields no (law, parameter, seed) combination; the
        median summaries could not be computed from an empty table.
    """
    cfg, seeds, params = load_manifest(manifest_path)
    histories_per_setting = cfg["runs"]["histories_per_setting"]
    out_dir = Path(cfg["outputs"]["dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    # Phase metadata is the same for every row, so serialise it only once.
    # ``or {}`` also tolerates a key that is present but null in the YAML.
    phase_cfg = cfg.get("phase") or {}
    phase_windows = cfg.get("instrument_phase_windows") or {}
    phase_json = json.dumps(
        {
            "B": phase_cfg.get("B"),
            "P": phase_cfg.get("period_px"),
            "x0": phase_cfg.get("x0"),
            "eps_bins": phase_windows.get("eps_bins"),
        }
    )

    # NOTE(review): every (law, param) setting re-seeds with the same seed, so
    # a given seed applies the identical noise draw to all settings. This is
    # the original behaviour; confirm the correlation is intended.
    rows = []
    for law, param_values in params.items():
        for param in param_values:
            v_pred, param_name = _predicted_visibility(law, param)
            for seed in seeds:
                v_meas = _noisy_visibility(seed, v_pred)
                i_max, i_min = _window_counts(histories_per_setting, v_meas)
                rows.append(
                    {
                        "seed": seed,
                        "law": law,
                        "param_name": param_name,
                        "param_value": param,
                        "histories": histories_per_setting,
                        "I_max": i_max,
                        "I_min": i_min,
                        "V": v_meas,
                        "V_pred": v_pred,
                        "abs_err": abs(v_meas - v_pred),
                        "phase": phase_json,
                    }
                )
    if not rows:
        # Fail before writing anything instead of emitting an empty CSV and
        # then crashing with an opaque KeyError inside ``groupby``.
        raise ValueError(
            f"Manifest {manifest_path} defines no (jitter law, parameter, seed) combinations"
        )

    df = pd.DataFrame(rows)
    summary_path = out_dir / "simB_summary.csv"
    df.to_csv(summary_path, index=False)
    # Median across seeds for each law & parameter.
    med = df.groupby(["law", "param_value"])[["V", "V_pred", "abs_err"]].median().reset_index()
    med_path = out_dir / "simB_summary_median.csv"
    med.to_csv(med_path, index=False)

    # Ablations (ties_off, skip_moves, measure_drift): multiplicative factors
    # applied to the noisy base visibility of every summary row.
    ablations = {
        "ties_off": 0.85,
        "skip_moves": 0.8,
        "measure_drift": 0.75,
    }
    ab_rows = []
    for base in rows:
        for ab_name, factor in ablations.items():
            v_ab = max(0.0, min(1.0, base["V"] * factor))
            i_max, i_min = _window_counts(histories_per_setting, v_ab)
            # Overridden keys keep their position; "ablation" is appended last.
            ab_rows.append(
                {
                    **base,
                    "I_max": i_max,
                    "I_min": i_min,
                    "V": v_ab,
                    "abs_err": abs(v_ab - base["V_pred"]),
                    "ablation": ab_name,
                }
            )
    df_ab = pd.DataFrame(ab_rows)
    ablation_path = out_dir / "simB_ablation.csv"
    df_ab.to_csv(ablation_path, index=False)
    # Median of the ablations across seeds.
    med_ab = df_ab.groupby(["law", "param_value", "ablation"])[["V", "abs_err"]].median().reset_index()
    med_ab_path = out_dir / "simB_ablation_median.csv"
    med_ab.to_csv(med_ab_path, index=False)
    print(f"Wrote main summary to {summary_path}")
    print(f"Wrote ablation summary to {ablation_path}")
    print(f"Wrote median summaries to {med_path} and {med_ab_path}")


if __name__ == "__main__":
    # Command-line entry point: the only option is the manifest location.
    cli = argparse.ArgumentParser(description="Run Simulation B according to manifest YAML.")
    cli.add_argument(
        "--manifest",
        type=str,
        default="sim/manifest.yaml",
        help="Path to the manifest YAML file",
    )
    cli_args = cli.parse_args()
    main(cli_args.manifest)